Implement vectorized jacobian and allow arbitrary expression shapes #1228
base: main
Conversation
Force-pushed 17ec0cd to 3c6ba6a, then 3c6ba6a to ff732d6.
Pull Request Overview
This PR implements an enhanced jacobian computation that supports vectorization and arbitrary expression shapes, yielding improved performance while preserving compatibility via a fallback scan method. Key changes include:
- Adding a new vectorized branch to the jacobian function in pytensor/gradient.py along with new tests covering scalar, vector, and matrix cases.
- Updating tests in tests/test_gradient.py to use a parameterized test class for the new vectorize functionality.
- Minor type hint and function signature adjustments in pytensor/tensor/basic.py and pytensor/graph/replace.py.
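
To make the new call pattern concrete, here is a minimal usage sketch (my own illustration, assuming the `vectorize` keyword and arbitrary-shape support this PR describes; the exact return-shape convention is an assumption):

```python
import pytensor
import pytensor.tensor as pt
from pytensor.gradient import jacobian

x = pt.dvector("x")
expr = pt.outer(x, x)  # a matrix-valued expression (arbitrary expression shape)

# New vectorized path; vectorize=False would fall back to the Scan-based loop
J = jacobian(expr, x, vectorize=True)

f = pytensor.function([x], J)
print(f([1.0, 2.0]).shape)  # assumed convention: expr.shape + x.shape, i.e. (2, 2, 2)
```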
Reviewed Changes
Copilot reviewed 4 out of 4 changed files in this pull request and generated no comments.
| File | Description |
|---|---|
| tests/test_gradient.py | New parameterized tests for the jacobian function with the added vectorize flag. |
| pytensor/tensor/basic.py | Added an early return in flatten when the reshaping yields the same number of dimensions. |
| pytensor/graph/replace.py | Updated return type annotations for enhanced clarity and consistency. |
| pytensor/gradient.py | Extended the jacobian function with a vectorize branch and adjusted inner gradient handling. |
Comments suppressed due to low confidence (3)
pytensor/gradient.py:2094
- Using zip with strict=True requires Python 3.10 or later. Please confirm that this requirement is acceptable for the project or consider alternative implementations for compatibility.
```python
for i, (jacobian_single_row, jacobian_matrix) in enumerate(
    zip(jacobian_single_rows, jacobian_matrices, strict=True)
):
```
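If pre-3.10 compatibility were needed, one alternative (a sketch of my own, not what the PR does) is an explicit length check before a plain `zip`:

```python
# Hypothetical compatibility variant: validate lengths manually instead of strict=True
if len(jacobian_single_rows) != len(jacobian_matrices):
    raise ValueError("jacobian_single_rows and jacobian_matrices must have equal length")
for i, (jacobian_single_row, jacobian_matrix) in enumerate(
    zip(jacobian_single_rows, jacobian_matrices)
):
    ...  # same loop body as in the PR
```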
pytensor/gradient.py:2104
- [nitpick] In the non-vectorized branch of jacobian, the inner function unpacks arguments to compute grad(expr[idx], wrt, **grad_kwargs). Please verify that passing 'wrt' as a list directly matches the intended grad() API to avoid potential issues with multidimensional expressions.
```python
idx, expr, *wrt = args
```
pytensor/gradient.py:2322
- The change from g_out.zeros_like() to g_out.zeros_like(g_out) is unexpected compared to previous usage. Please ensure that the new call properly infers the shape and dtype without introducing recursion or inconsistency.
```python
return [g_out.zeros_like(g_out) for g_out in g_outs]
```
Force-pushed 1764707 to e2a2665.
Also allow arbitrary expression dimensionality
Force-pushed e2a2665 to 7f161ae.
We can follow up with a vectorized Hessian, even though @aseyboldt doesn't believe anyone ever needs them. Allowing arbitrary expression shapes is not as trivial. JAX returns some nested tuple stuff...
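
For reference, a follow-up vectorized Hessian could in principle reuse the same machinery, since the Hessian is the Jacobian of the gradient (a hedged sketch of my own, not part of this PR):

```python
import pytensor.tensor as pt
from pytensor.gradient import grad, jacobian

x = pt.dvector("x")
cost = pt.sum(x ** 3)

# Jacobian of the gradient == Hessian; vectorize=True would use the new vectorized path
H = jacobian(grad(cost, x), x, vectorize=True)  # symbolic (n, n) matrix
```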
Do you want to add a benchmark of scan vs vectorize?
```diff
-respect to some parameter ``x`` we need to use `scan`. What we
-do is to loop over the entries in ``y`` and compute the gradient of
+respect to some parameter ``x`` we can use `scan`.
+We loop over the entries in ``y`` and compute the gradient of
```
Suggested change:
```suggestion
In this case, we loop over the entries in ``y`` and compute the gradient of
```
```diff
 ``y[i]`` with respect to ``x``.

 .. note::

    `scan` is a generic op in PyTensor that allows writing in a symbolic
    manner all kinds of recurrent equations. While creating
    symbolic loops (and optimizing them for performance) is a hard task,
-   effort is being done for improving the performance of `scan`. We
-   shall return to :ref:`scan<tutloop>` later in this tutorial.
+   effort is being done for improving the performance of `scan`.
```
Suggested change:
```suggestion
efforts are being made to improve the performance of `scan`.
```
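
For context, the Scan-based loop that the quoted tutorial passage describes looks roughly like the standard PyTensor Jacobian example (a sketch of my own; the actual snippet isn't shown in this diff):

```python
import pytensor
import pytensor.tensor as pt

x = pt.dvector("x")
y = x ** 2

# Loop over the entries of y and take the gradient of each y[i] with respect to x
J, updates = pytensor.scan(
    lambda i, y, x: pytensor.grad(y[i], x),
    sequences=pt.arange(y.shape[0]),
    non_sequences=[y, x],
)
f = pytensor.function([x], J, updates=updates)
print(f([4.0, 4.0]))
# [[8. 0.]
#  [0. 8.]]
```

The vectorized alternative introduced in the next quoted snippet avoids this symbolic loop entirely.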
```python
>>> from pytensor.graph import vectorize_graph
>>> x = pt.dvector('x')
>>> y = x ** 2
>>> row_tangent = pt.dvector("row_tangent")  # Helper variable, it will be replaced during vectorization
```
The term I think gets used is cotangent_vector?
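
The quoted doc excerpt is cut off above; a hedged sketch of how the full Jacobian might then be assembled with `vectorize_graph` (my own continuation under that assumption, not the exact doc text):

```python
import pytensor
import pytensor.tensor as pt
from pytensor.graph import vectorize_graph

x = pt.dvector("x")
y = x ** 2
row_tangent = pt.dvector("row_tangent")  # helper, replaced during vectorization

# One vector-Jacobian product: gradient of <y, row_tangent> with respect to x
J_row = pytensor.grad(pt.sum(y * row_tangent), x)

# Replace the helper with the rows of an identity matrix to batch over all rows at once
J = vectorize_graph(J_row, {row_tangent: pt.eye(x.shape[0])})

f = pytensor.function([x], J)
print(f([4.0, 4.0]))
# [[8. 0.]
#  [0. 8.]]
```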
```diff
@@ -2051,62 +2057,73 @@ def jacobian(expression, wrt, consider_constant=None, disconnected_inputs="raise
     output, then a zero variable is returned. The return value is
     of same type as `wrt`: a list/tuple or TensorVariable in all cases.
     """
+    from pytensor.tensor import broadcast_to, eye
```
Use the actual source for the import? tensor.basic or tensor.shape, I guess?
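
A hedged reading of this suggestion (the module locations are my assumption, not verified against this PR):

```python
# Import each helper from the module that defines it rather than the pytensor.tensor
# namespace (locations assumed: eye in tensor.basic, broadcast_to in tensor.extra_ops)
from pytensor.tensor.basic import eye
from pytensor.tensor.extra_ops import broadcast_to
```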
```diff
 )

 amat = matrix()
 amat_val = random(4, 5)
-for ndim in (2, 1):
+for ndim in (1,):
```
Remove useless loop
Codecov Report

Attention: Patch coverage is 89.28%.

❌ Your patch check has failed because the patch coverage (89.28%) is below the target coverage (100.00%). You can increase the patch coverage or adjust the target coverage.

Additional details and impacted files:

```
@@            Coverage Diff             @@
##             main    #1228      +/-   ##
==========================================
- Coverage   82.03%   82.03%   -0.01%
==========================================
  Files         214      214
  Lines       50398    50408      +10
  Branches     8897     8902       +5
==========================================
+ Hits        41345    41352       +7
- Misses       6848     6850       +2
- Partials     2205     2206       +1
```
Seeing about a 3.5x (2x for larger x) time speedup for this trivial case:
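
The exact snippet isn't included here; a minimal sketch of the kind of comparison described, assuming the new `vectorize` flag (not the author's original code):

```python
import numpy as np
import pytensor
import pytensor.tensor as pt
from pytensor.gradient import jacobian

x = pt.dvector("x")
y = pt.exp(x)  # a trivial elementwise expression

f_scan = pytensor.function([x], jacobian(y, x, vectorize=False))  # Scan-based loop
f_vec = pytensor.function([x], jacobian(y, x, vectorize=True))    # vectorized graph

x_val = np.random.normal(size=1_000)
np.testing.assert_allclose(f_scan(x_val), f_vec(x_val))
# In IPython: %timeit f_scan(x_val); %timeit f_vec(x_val)
```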
Memory footprint will grow though, especially if intermediate operations are much larger than the final jacobian. Also, not all graphs will be safely vectorizable, so I would leave the Scan option as the default for a while.
Related Issue